package com.twitter.finagle.mux.pushsession

import com.twitter.finagle.mux.{Handshake, Request, Response}
import com.twitter.finagle.mux.Handshake.{CanTinitMsg, Headers, TinitTag}
import com.twitter.finagle.{Mux, Service, Status}
import com.twitter.finagle.mux.transport.Message
import com.twitter.finagle.pushsession.{PushSession, RefPushSession}
import com.twitter.io.{Buf, ByteReader}
import com.twitter.logging.{Level, Logger}
import com.twitter.util._
import scala.util.control.NonFatal

/**
 * Session which negotiates Mux features
 *
 * Instances of the `MuxServerNegotiator` assume ownership of the provided `Service` and
 * `PushChannelHandle`. Upon successful completion of the negotiation process ownership
 * of these is transferred to the session generated by the `negotiate` function.
 *
 * Thread safety considerations
 * - `receive` is only intended to be called from within the serial executor associated with
 *   the provided `PushChannelHandle`.
 * - `negotiate` will be called from within the `serialExecutor` of the provided `PushChannelHandle`.
 * - `close` and `status` are safe to call from any thread.
 */
private[finagle] class MuxServerNegotiator private (
  refSession: RefPushSession[ByteReader, Buf],
  handle: MuxChannelHandle,
  service: Service[Request, Response],
  makeLocalHeaders: Headers => Headers,
  negotiate: (Service[Request, Response], Option[Headers]) => PushSession[ByteReader, Buf],
  timer: Timer)
    extends PushSession[ByteReader, Buf](handle) {
  import MuxServerNegotiator.log

  /** One step of the handshake state machine: consumes a single decoded inbound message. */
  private[this] type Phase = Message => Unit

  // Satisfied with the negotiated session on success, or failed when negotiation fails.
  // `close` waits on this so that it can delegate to the negotiated session (for draining).
  private[this] val sessionP = Promise[PushSession[ByteReader, Buf]]

  // The current handshake step. Reassigned by the phase handlers as the handshake progresses;
  // these are driven from `receive` (intended to run in the handle's serial executor).
  private[this] var handshakePhase: Phase = checkRerrPhase

  handle.onClose.ensure {
    // If we have already completed negotiation, no need to close ourselves
    if (!sessionP.isDefined) close()
  }

  def onClose: Future[Unit] = handle.onClose

  /**
   * Close this session.
   *
   * If negotiation produces a session by `deadline`, the close is delegated to that session,
   * which owns the handle and service from then on and may drain gracefully. Otherwise
   * (negotiation failed or did not finish in time) the handle and service are closed
   * directly, still honoring `deadline`.
   */
  def close(deadline: Time): Future[Unit] = {
    // We want to proxy close calls to the underlying session, provided it resolves in time.
    // This facilitates draining behavior.
    sessionP.by(deadline)(timer).transform {
      case Return(session) => session.close(deadline)
      // Forward the caller's deadline rather than forcing an immediate close.
      case Throw(_) => Closable.all(handle, service).close(deadline)
    }
  }

  def status: Status = handle.status

  /**
   * Decode one inbound message and feed it to the current handshake phase.
   *
   * Any non-fatal error closes this session and is rethrown to the caller. The reader is
   * always released.
   */
  def receive(message: ByteReader): Unit = {
    try {
      val msg = Message.decode(message)
      handshakePhase(msg)
    } catch {
      case NonFatal(t) =>
        close()
        throw t
    } finally message.close()
  }

  /**
   * First phase. A client that wants to negotiate sends an `Rerr` probe carrying
   * `TinitTag`/`CanTinitMsg`. We echo the probe and wait for a `Tinit`. Any other first
   * message means the client skips negotiation, so a session is built without headers and
   * handed that message.
   */
  private[this] def checkRerrPhase: Phase = {
    case Message.Rerr(Handshake.TinitTag, Handshake.CanTinitMsg) =>
      if (log.isLoggable(Level.DEBUG)) {
        log.debug(s"Received Rerr prelude to Tinit. $remoteAddressString")
      }
      // Prepare to receive the Tinit and send the Rerr reply
      handshakePhase = getInitPhase
      handle.sendAndForget(Message.encode(Message.Rerr(TinitTag, CanTinitMsg)))

    case message => // No negotiation: just init a basic session
      if (log.isLoggable(Level.DEBUG)) {
        log.debug(
          s"Rerr prelude not detected (received ${message.getClass.getSimpleName}. " +
            s"Skipping Init phase. $remoteAddressString"
        )
      }
      noInit(message)
  }

  /**
   * Second phase, entered after the `Rerr` probe.
   *
   * - A `Tinit` at `Mux.LatestVersion` is answered with an `Rinit` carrying our local
   *   headers, and then the session is negotiated with the client's headers.
   * - A `Tinit` at any other version fails negotiation.
   * - Any other message is treated as "no negotiation" and handled like `checkRerrPhase`
   *   would handle a non-probe message.
   */
  private[this] def getInitPhase: Phase = {
    case Message.Tinit(tag, Mux.LatestVersion, headers) =>
      try {
        val localHeaders = makeLocalHeaders(headers)
        // We need to enqueue the message *now* so that it makes it down the pipeline before
        // we potentially install the Netty TLS ChannelHandler. If we use the standard
        // `sendAndForget` we bounce the send through the serial executor. If we try to execute
        // the pipeline changes via the continuation provided to `send`, we are not guaranteed
        // that it will happen before handling inbound bytes.

        if (log.isLoggable(Level.TRACE)) {
          log.trace(
            s"Server received client headers $headers; Server sending headers $localHeaders"
          )
        }

        handle.sendNowAndForget(Message.encode(Message.Rinit(tag, Mux.LatestVersion, localHeaders)))

        if (log.isLoggable(Level.TRACE))
          log.trace(s"Server has sent $localHeaders")

        // pipeline changes for Opp-TLS
        val session = negotiate(service, Some(headers))
        negotiationSuccess(session)
      } catch {
        case NonFatal(t) => negotiationFailure(tag, t)
      }

    case Message.Tinit(tag, v, _) =>
      val ex = new IllegalStateException(s"Unsupported Mux version: $v")
      negotiationFailure(tag, ex)

    case message =>
      log.warning(
        "Received Rerr init probe but didn't receive a follow up " +
          s"Tinit (received a ${message.getClass.getSimpleName}). $remoteAddressString"
      )
      noInit(message)
  }

  // helper for logging purposes
  private[this] def remoteAddressString: String = s"remote: ${handle.remoteAddress}"

  /**
   * Build a session without client headers, install it, then re-encode `message` and hand it
   * to that session.
   *
   * If the new session's `receive` throws, `sessionP` has already been satisfied by the time
   * `negotiationFailure` runs. `negotiationFailure` tolerates that case.
   */
  private[this] def noInit(message: Message): Unit = {
    try {
      val session = negotiate(service, None)
      // Register the new session and then give it the message
      negotiationSuccess(session)
      session.receive(ByteReader(Message.encode(message)))
    } catch {
      case NonFatal(t) => negotiationFailure(message.tag, t)
    }
  }

  private[this] def negotiationSuccess(session: PushSession[ByteReader, Buf]): Unit = {
    // Upon success, we need to
    // 1. redirect events to the new session
    // 2. satisfy the promise which will trigger any pending `close` calls that occurred
    //    before negotiation completed.
    // The order of these events is not strictly important but its generally a good idea to
    // reshape our pipeline before triggering potentially arbitrary events via satisfying a
    // promise.
    refSession.updateRef(session)
    sessionP.setValue(session)
    handshakePhase = { m: Message =>
      // Should never get here
      val msg = "After negotiation success, more messages were sent to the " +
        s"negotiating session: ${m.getClass.getSimpleName}. $remoteAddressString"
      val ex = new IllegalStateException(msg)
      log.error(ex, msg)
      handle.close()
    }
  }

  /**
   * Negotiation failed.
   *
   * Swallow any further inbound messages, reply to `tag` with an `Rerr` describing `t`, then
   * close this session and fail `sessionP` with `t` if it has not already been satisfied.
   */
  private[this] def negotiationFailure(tag: Int, t: Throwable): Unit = {
    // Negotiation failed, so we need to cleanup and shutdown.
    log.warning(t, s"Negotiation failed. Closing session. $remoteAddressString")

    // This is not necessarily unexpected if the peer thinks the session is now open
    handshakePhase = { msg: Message =>
      if (log.isLoggable(Level.DEBUG)) {
        log.debug(
          t,
          s"Session negotiation failed. Swallowing message " +
            s"(${msg.getClass.getSimpleName}) that raced shutdown. $remoteAddressString"
        )
      }
    }

    // `getMessage` may be null, which is not a usable Rerr payload; fall back to the class name.
    val errorMessage = Option(t.getMessage).getOrElse(t.getClass.getName)
    handle.send(Message.encode(Message.Rerr(tag, errorMessage))) { _ =>
      close().ensure {
        // `updateIfEmpty` rather than `setException`: when reached from `noInit` after
        // `negotiationSuccess`, `sessionP` is already satisfied, and completing a Promise a
        // second time throws.
        sessionP.updateIfEmpty(Throw(t))
      }
    }
  }
}

private[finagle] object MuxServerNegotiator {
  private val log = Logger.get

  /**
   * Create a [[MuxServerNegotiator]] and install it as the current target of `ref`, so that
   * it handles inbound events until negotiation completes.
   *
   * @param ref the session reference to install the negotiator into. On success the
   *            negotiator repoints it at the negotiated session.
   * @param handle the channel handle; ownership passes to the negotiator.
   * @param service the service to serve; ownership passes to the negotiator.
   * @param makeLocalHeaders computes the server's handshake headers from the client's.
   * @param negotiate builds the final session from the service and the client's headers
   *                  (`None` when the client skips the Tinit handshake).
   * @param timer bounds how long `close` waits for negotiation to finish.
   */
  def build(
    ref: RefPushSession[ByteReader, Buf],
    handle: MuxChannelHandle,
    service: Service[Request, Response],
    makeLocalHeaders: Headers => Headers,
    negotiate: (Service[Request, Response], Option[Headers]) => PushSession[ByteReader, Buf],
    timer: Timer
  ): Unit =
    ref.updateRef(
      new MuxServerNegotiator(
        refSession = ref,
        handle = handle,
        service = service,
        makeLocalHeaders = makeLocalHeaders,
        negotiate = negotiate,
        timer = timer
      )
    )
}
